##
# Example solutions to the exercises in part 2
#
# It is useful if you try to solve them yourself first.
# Of course you're welcome to steal from Google and your neighbour!
##







































##
# Example solution to 2a
#
# If you haven't yet solved this exercise, please do so first!
#
##

waveforms = []
times = []
data_dir = 'data'

# Get all subdirectories. os.listdir also returns plain files, so keep only
# directories; sort so the load order is reproducible (listdir order is
# arbitrary).
subdirs = sorted(
    entry for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)

for i, subdir in enumerate(subdirs):

    # Get all .csv files in the subdirectory, in a deterministic order
    pattern = os.path.join(data_dir, subdir, '*.csv')
    files = sorted(glob.glob(pattern))

    # Show a 1-based "k/N" progress label (k runs from 1 to N)
    for filename in tqdm(files,
                         desc='Loading dir %d/%d' % (i + 1, len(subdirs))):
        # The file name, minus directory and extension, is the time value
        without_dir = os.path.basename(filename)
        without_ext = os.path.splitext(without_dir)[0]
        times.append(float(without_ext))

        # Extract the waveform; times[k] and waveforms[k] stay paired
        wv = np.loadtxt(filename)
        waveforms.append(wv)


























##
# Example solution for 2b
##


def count_above(data, threshold):
    """Count separate stretches of ``data`` that lie strictly above ``threshold``.

    A stretch begins at a sample ``> threshold`` seen while not already inside
    a stretch. It ends at the next sample ``<= threshold``. Each stretch is
    counted once, however many samples it spans.
    """
    crossings = 0
    above = False
    for value in data:
        if above:
            # Inside a stretch: only a sample at/below threshold ends it
            if value <= threshold:
                above = False
        elif value > threshold:
            # Outside a stretch: a sample above threshold starts a new one
            above = True
            crossings += 1
    return crossings































##
# Example solution for 2c
#
# Here I fit the mean number of counts per 1000-sample waveform;
# the photon detection rate per sample is just 0.001 x that result.
#
##

bin_edges = np.linspace(0, 5, 100)
bin_centers = edges_to_centers(bin_edges)

# Per voltage bin: how many waveforms fell in it, and their summed counts
n_waveforms, _ = np.histogram(data.voltage, bins=bin_edges)
tot_counts, _ = np.histogram(data.voltage, weights=data.counts, bins=bin_edges)

samples_per_waveform = 1000

# Mean counts per waveform and its Poisson error. Bins without any waveforms
# have no estimate: leave them NaN explicitly instead of dividing by zero
# (which gives the same NaN but emits RuntimeWarnings).
has_waveforms = n_waveforms > 0
mu = np.divide(tot_counts, n_waveforms,
               out=np.full(n_waveforms.shape, np.nan), where=has_waveforms)
sigma = np.divide(np.sqrt(tot_counts), n_waveforms,
                  out=np.full(n_waveforms.shape, np.nan), where=has_waveforms)

plt.errorbar(bin_centers,
             mu,
             yerr=sigma)
plt.show()

def rate_model(voltage, scale, v_threshold, power):
    """Power-law rate model above a threshold voltage.

    Returns ``scale * max(voltage - v_threshold, 0) ** power`` element-wise.
    For a positive ``power`` this is zero at or below ``v_threshold``.
    """
    excess = np.clip(voltage - v_threshold, 0, np.inf)
    return scale * excess ** power

# Compare the data with the initial parameter guess before fitting
plt.plot(bin_centers, mu)
plt.plot(bin_centers, rate_model(bin_centers, 2, 1, 1.5))
plt.show()

# curve_fit raises ValueError on NaN/inf input by default, and empty voltage
# bins give mu = 0/0 = NaN, so fit only the bins where mu is finite.
fit_mask = np.isfinite(mu)
x_fit = np.asarray(bin_centers)[fit_mask]
y_fit = np.asarray(mu)[fit_mask]

# The fit is unweighted: sigma = sqrt(counts)/n is exactly 0 for bins without
# counts, so it cannot be passed as-is as curve_fit's sigma.
popt, pcov = optimize.curve_fit(rate_model, x_fit, y_fit, p0=[2, 1, 1.5])
perr = np.sqrt(np.diag(pcov))

plt.plot(bin_centers, mu)
plt.plot(bin_centers, rate_model(bin_centers, *popt))
plt.show()

for value, error, label in zip(popt, perr, ['scale', 'v_t', 'p']):
    print("%s = %0.3f +- %0.3f" % (label, value, error))

# The model is in counts per waveform; dividing scale by the number of
# samples per waveform expresses it in counts per sample.
print("scale per sample = %0.3g +- %0.3g"
      % (popt[0] / samples_per_waveform, perr[0] / samples_per_waveform))

